Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust forward stage2 to Core.Compiler changes #295

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

Keno
Copy link
Collaborator

@Keno Keno commented Oct 2, 2024

Only what is necessary for Cedar right now. Ordinary stage 2 reverse mode will need similar changes at a later point.

@Keno Keno requested a review from aviatesk October 2, 2024 22:54

local frule_call::Future{CallMeta}
local result::Future{Union{CallMeta, Nothing}} = Future{Union{CallMeta, Nothing}}()
function make_progress(_, sv)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vtjnash, please confirm that this is the intended way to use this.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You seem to have closure captured interp instead of using the argument? The interp struct is commonly quite large, so that can increase memory usage quite a bit

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The argument is the wrong interp

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this way of defining a make_progress seems fine. There isn't really one right answer about how to code this, so base itself already uses probably 3 or 4 different patterns, depending on what kept the original code control flow seemed least distorted. I hadn't used the @isdefined trick, but it is essentially equivalent to the nextstate pattern I'd used for manual stackless state machine conversion

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think capturing the interp will give you the behavior you want. I think you might need to mutate sv instead of re-using sv with different interp

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not re-using sv with a different interp. sv here is an IRInterpretationState, which doesn't have an interp argument, so when the callback later gets scheduled, there's just some random interp in there.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sv claims to be a AbsIntState here? For IRInterpretationState currently it passes the interp here that was originally used to construct the IRInterpretationState, since everything is on the stack there and doesn't handle recursion

I think the behavior here is also probably fine, but that no other callback will be using the right interp, since none other are expecting the interp to be different from the one used to allocate the state object

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and doesn't handle recursion

Doesn't handle recursion in Base ;).

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair, I know almost nothing about this code, so I am reviewing without really knowing how this integrates. The current implementation in Base would potentially break here: https://github.com/JuliaLang/julia/blob/be401635fe02b28ce994e2e3cae0733d101f8927/base/compiler/ssair/irinterp.jl#L154
since it was not tracking if the return type changed to reschedule this instruction if it became part of cycle (I believe it should detect and @assert though if that attempts to happen)

Keno added a commit to CedarEDA/DAECompiler.jl that referenced this pull request Oct 3, 2024
Keno added a commit to CedarEDA/DAECompiler.jl that referenced this pull request Oct 3, 2024
Only what is necessary for Cedar right now. Ordinary stage 2 reverse
mode will need similar changes at a later point.
Copy link
Member

@aviatesk aviatesk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

nargs = length(arginfo.argtypes)-1
frule_preargtypes = Any[Const(ChainRulesCore.frule), Tuple{Nothing,Vararg{Any,nargs}}]
frule_argtypes = append!(frule_preargtypes, arginfo.argtypes)
frule_atype = CC.argtypes_to_type(frule_argtypes)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be careful not to closure capture any types, as your performance may suffer quite badly, but still just fast enough you won't notice (e.g. the sysimage could build still when I missed one of these cases, it just took several times longer)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So put it in a Ref{Any}?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, or a Core.Box equivalently. I'd found several places where we had a MethodMatch object that was needed anyways, so that also happened to work sometimes

local result::Future{Union{CallMeta, Nothing}} = Future{Union{CallMeta, Nothing}}()
function make_progress(_, sv)
if isa(primal_call[].info, UnionSplitApplyCallInfo)
result[] = nothing
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type appears to be wrong here. The intended behavior appears to be returning result[] = primal_call[] in this case (

r = fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret)
if r !== nothing
return r
end
)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the function that calls this. Potentially it should be refactored to just do that, but I just wanted to only make the refactoring change.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return type is required to be Future{CallMeta} though, or the caller's caller will be unhappy

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The call-site there isn't updated yet. This function is called directly from DAECompiler and I adjusted the call-site there to work with Future{Union{Nothing, CallMeta}}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is fine if it still branches on r!==nothing there, it will just be dead code now, as it appears you you must handle that case here now, instead of being able to handle it there

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants